1
从急切模式操作到基于块的并行计算
AI023Lesson 3
00:00

PyTorch 急切模式 过渡到 Triton 要求我们从将张量视为整体对象,转变为将其看作一组离散且可管理的 或区块。

1. PyTorch 与 Triton 张量对比

必须明确区分 Triton 张量PyTorch 张量。PyTorch 张量是一个 主机端的 Python 对象 封装了形状、数据类型、设备、步长和存储元数据。相比之下,Triton 处理的是特定内存块内的 原始数据指针 ,从而实现更底层的优化。

2. 急切模式的瓶颈

在标准的急切执行中,每次操作(例如加法后接 ReLU)都需要一次独立的内核启动和一次 全局内存往返传输。这是现代 GPU 计算中的主要瓶颈。Triton 通过在单个内核中融合多个操作来克服这一问题,该内核直接在片上内存中处理数据块(例如 128、256 或 512 个元素)。 融合 操作,这些操作在一个内核中处理数据块(例如 128、256 或 512 个元素),并直接在片上内存中进行。

3. 基于块的范式

与 CUDA 线程的标量级思维不同,Triton 在块级别使用 SPMD(单程序多数据) 。你只需编写一个内核,Triton 就会在网格中启动多个实例。每个实例使用其 program_id 来计算它所拥有的“块”内存区域。

PyTorch 张量[元数据包装器]块 0(pid 0)块 1(pid 1)块 2(pid 2)

4. 环境设置

开始之前,请 在干净的环境中安装 Triton (使用 Conda 或 venv)以确保不会与现有的 CUDA 工具包发生依赖冲突: pip install triton

main.py
TERMINALbash — 80x24
> Ready. Click "Run" to execute.
>